/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 *    ResidualSplit.java
 *    Copyright (C) 2003 Niels Landwehr
 *
 */

package weka.classifiers.trees.lmt;

import java.io.*;
import java.util.*;
import weka.core.*;
import weka.classifiers.*;
import weka.classifiers.trees.j48.*;
import weka.classifiers.functions.*;
import weka.filters.unsupervised.attribute.Remove;
import weka.filters.Filter;
import weka.filters.supervised.attribute.NominalToBinary;

/**
 * Helper class for logistic model trees (weka.classifiers.trees.lmt.LMT) that
 * implements the splitting criterion based on the residuals of the LogitBoost
 * algorithm.
 * <p>
 * A candidate split is scored by how much it reduces the weighted sum of squared
 * deviations of the LogitBoost responses (Z-values) from their weighted per-class
 * means (see {@link #entropyGain()}). Nominal attributes yield one subset per
 * value; numeric attributes yield a binary split at the best midpoint between
 * consecutive distinct attribute values.
 * 
 * @author Niels Landwehr
 * @version $Revision: 1.1 $
 */
public class ResidualSplit extends ClassifierSplitModel{

    /** The attribute selected for the split. */
    protected Attribute m_attribute;

    /** The index of the attribute selected for the split. */
    protected int m_attIndex;

    /** Number of instances in the set. */
    protected int m_numInstances;

    /** Number of classes. */
    protected int m_numClasses;

    /** The set of instances. */
    protected Instances m_data;

    /** The Z-values (LogitBoost response) for the set of instances, indexed [instance][class]. */
    protected double[][] m_dataZs;

    /** The LogitBoost weights for the set of instances, indexed [instance][class]. */
    protected double[][] m_dataWs;

    /** The split point (only meaningful for numeric attributes). */
    protected double m_splitPoint;

    /**
     * Creates a split object.
     *
     * @param attIndex the index of the attribute to split on
     */
    public ResidualSplit(int attIndex) {
        m_attIndex = attIndex;
    }

    /**
     * Builds the split. Needs the Z/W values of LogitBoost for the set of instances.
     * <p>
     * Row i of dataZs/dataWs is used together with instance i of data, and
     * columns are indexed by class. NOTE(review): the arrays are assumed to be
     * aligned with data; this is not checked here.
     *
     * @param data the instances to split
     * @param dataZs LogitBoost Z-values, indexed [instance][class]
     * @param dataWs LogitBoost weights, indexed [instance][class]
     * @throws Exception if data contains no instances, or if evaluating the
     *         split fails (e.g. missing values for a numeric split attribute)
     */
    public void buildClassifier(Instances data, double[][] dataZs, double[][] dataWs)
        throws Exception {

        m_numClasses = data.numClasses();
        m_numInstances = data.numInstances();
        if (m_numInstances == 0) throw new Exception("Can't build split on 0 instances");

        // keep references to the data and its LogitBoost statistics
        m_data = data;
        m_dataZs = dataZs;
        m_dataWs = dataWs;
        m_attribute = data.attribute(m_attIndex);

        // determine the number of subsets, and the split point for numeric attributes
        if (m_attribute.isNominal()) {
            m_splitPoint = 0.0;
            m_numSubsets = m_attribute.numValues();
        } else {
            // if no split point is found, m_splitPoint is left unchanged
            getSplitPoint();
            m_numSubsets = 2;
        }

        // create the class distribution over the resulting subsets
        m_distribution = new Distribution(data, this);
    }

    /**
     * Selects the split point for a numeric attribute. Candidates are the
     * midpoints between consecutive distinct values (compared with
     * Utils.eq) of the sorted attribute values. The candidate with the largest
     * {@link #entropyGain()} is chosen.
     *
     * @return true if a split point was selected and stored in m_splitPoint;
     *         false otherwise, in which case m_splitPoint is left unchanged
     * @throws Exception if evaluating a candidate fails (e.g. missing values)
     */
    protected boolean getSplitPoint() throws Exception{

        if (m_numInstances == 0) return false;

        // Only the attribute values are needed for finding candidates, so sort
        // those directly instead of copying and sorting the whole dataset.
        // Arrays.sort places NaN (Weka's missing value) last, like Instances.sort.
        double[] values = new double[m_numInstances];
        for (int i = 0; i < m_numInstances; i++) {
            values[i] = m_data.instance(i).value(m_attIndex);
        }
        Arrays.sort(values);

        // compute possible split points
        double[] splitPoints = new double[m_numInstances];
        int numSplitPoints = 0;
        double last = values[0];
        for (int i = 1; i < m_numInstances; i++) {
            double current = values[i];
            if (!Utils.eq(current, last)) {
                splitPoints[numSplitPoints++] = (last + current) / 2.0;
            }
            last = current;
        }

        // entropyGain() reads m_splitPoint, so each candidate is installed
        // temporarily; remember the old value to restore it on failure
        double previousSplitPoint = m_splitPoint;
        int bestSplit = -1;
        double bestGain = -Double.MAX_VALUE;

        for (int i = 0; i < numSplitPoints; i++) {
            m_splitPoint = splitPoints[i];
            double gain = entropyGain();
            // strict '>' keeps the first of equal gains; NaN gains never win
            if (gain > bestGain) {
                bestGain = gain;
                bestSplit = i;
            }
        }

        if (bestSplit < 0) {
            m_splitPoint = previousSplitPoint;
            return false;
        }

        m_splitPoint = splitPoints[bestSplit];
        return true;
    }

    /**
     * Computes the gain of the current split: the value of {@link #entropy}
     * on the whole set minus the sum of its values over the subsets.
     *
     * @return the reduction in weighted squared error achieved by the split
     * @throws Exception if an instance has a missing value for the split attribute
     */
    public double entropyGain() throws Exception{

        int numSubsets = m_attribute.isNominal() ? m_attribute.numValues() : 2;

        // assign every instance to its subset once and count subset sizes
        int[] subsetOf = new int[m_numInstances];
        int[] subsetSize = new int[numSubsets];
        for (int i = 0; i < m_numInstances; i++) {
            int subset = whichSubset(m_data.instance(i));
            if (subset < 0) throw new Exception("ResidualSplit: no support for splits on missing values");
            subsetOf[i] = subset;
            subsetSize[subset]++;
        }

        double[][][] splitDataZs = new double[numSubsets][][];
        double[][][] splitDataWs = new double[numSubsets][][];
        for (int s = 0; s < numSubsets; s++) {
            splitDataZs[s] = new double[subsetSize[s]][];
            splitDataWs[s] = new double[subsetSize[s]][];
        }

        // distribute the Z/W rows into their subsets (rows are shared, not copied)
        int[] filled = new int[numSubsets];
        for (int i = 0; i < m_numInstances; i++) {
            int s = subsetOf[i];
            splitDataZs[s][filled[s]] = m_dataZs[i];
            splitDataWs[s][filled[s]] = m_dataWs[i];
            filled[s]++;
        }

        // calculate gain
        double entropyOrig = entropy(m_dataZs, m_dataWs);

        double entropySplit = 0.0;
        for (int s = 0; s < numSubsets; s++) {
            entropySplit += entropy(splitDataZs[s], splitDataWs[s]);
        }

        return entropyOrig - entropySplit;
    }

    /**
     * Helper function that computes the split criterion from Z/W values. For
     * each class j it sums w_ij * (z_ij - m_j)^2, where m_j is the w-weighted
     * mean of the z_ij, and it then adds these sums over all classes. Despite
     * the name, this is a weighted squared error rather than an entropy.
     *
     * @param dataZs Z-values, indexed [instance][class]
     * @param dataWs weights, indexed [instance][class]
     * @return the summed weighted squared deviations (0 for an empty set)
     */
    protected double entropy(double[][] dataZs, double[][] dataWs){
        double entropy = 0.0;
        int numInstances = dataZs.length;

        for (int j = 0; j < m_numClasses; j++) {

            // weighted mean of the Z-values for this class
            double weightedSum = 0.0;
            double sumOfWeights = 0.0;
            for (int i = 0; i < numInstances; i++) {
                weightedSum += dataZs[i][j] * dataWs[i][j];
                sumOfWeights += dataWs[i][j];
            }

            // With zero total weight the mean is undefined (0/0 = NaN), and
            // 0 * NaN would poison the result. Such a class contributes
            // nothing to the weighted sum of squares, so skip it.
            if (sumOfWeights == 0.0) continue;

            double mean = weightedSum / sumOfWeights;

            // sum up weighted squared deviations for this class
            for (int i = 0; i < numInstances; i++) {
                entropy += dataWs[i][j] * Math.pow(dataZs[i][j] - mean, 2);
            }
        }

        return entropy;
    }

    /**
     * Checks if there are at least 2 subsets that contain >= minNumInstances.
     *
     * @param minNumInstances minimum per-subset amount (as reported by
     *        Distribution.perBag) for a subset to count
     * @return true if at least two subsets reach the minimum
     */
    public boolean checkModel(int minNumInstances){
        int count = 0;
        for (int i = 0; i < m_distribution.numBags(); i++) {
            if (m_distribution.perBag(i) >= minNumInstances) count++;
        }
        return (count >= 2);
    }

    /**
     * Returns the name of the splitting attribute (left side of the condition).
     *
     * @param data the dataset providing the attribute
     * @return the attribute name
     */
    public final String leftSide(Instances data) {
        return data.attribute(m_attIndex).name();
    }

    /**
     * Prints the condition satisfied by instances in a subset.
     *
     * @param index the subset index
     * @param data the dataset providing the attribute
     * @return " = value" for nominal attributes, otherwise " <= splitPoint" for
     *         subset 0 and " > splitPoint" for any other subset
     */
    public final String rightSide(int index,Instances data) {

        StringBuffer text = new StringBuffer();
        if (data.attribute(m_attIndex).isNominal()) {
            text.append(" = " + data.attribute(m_attIndex).value(index));
        } else if (index == 0) {
            text.append(" <= " + Utils.doubleToString(m_splitPoint,6));
        } else {
            text.append(" > " + Utils.doubleToString(m_splitPoint,6));
        }
        return text.toString();
    }

    /**
     * Returns the subset an instance belongs to.
     *
     * @param instance the instance to classify
     * @return -1 if the split attribute is missing; the value index for nominal
     *         attributes; 0 if the value is <= m_splitPoint (within Utils
     *         tolerance), otherwise 1
     * @throws Exception declared by the overridden signature
     */
    public final int whichSubset(Instance instance)
        throws Exception {

        if (instance.isMissing(m_attIndex)) return -1;
        if (instance.attribute(m_attIndex).isNominal()) return (int)instance.value(m_attIndex);
        return Utils.smOrEq(instance.value(m_attIndex),m_splitPoint) ? 0 : 1;
    }

    /** Method not in use; the split is built via buildClassifier(Instances, double[][], double[][]). */
    public void buildClassifier(Instances data) {
        // intentionally a no-op
    }

    /** Method not in use; always returns null. */
    public final double [] weights(Instance instance){
        return null;
    }

    /** Method not in use; always returns the empty string. */
    public final String sourceExpression(int index, Instances data) {
        return "";
    }

}






